
🔁 🦈 Support iterative GRPO #2700

Merged · 15 commits · Feb 4, 2025

Conversation

@shirinyamani (Contributor) commented Jan 30, 2025

What does this PR do?

Following the thread of issue #2684, and based on the DeepSeek paper, we concluded that we need to add a feature that iteratively updates the reference model every ref_model_sync_steps steps.
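For illustration, a minimal sketch of how a user might enable this once the PR lands (the sync_ref_model flag name is an assumption mirroring OnlineDPOConfig; ref_model_sync_steps and ref_model_mixup_alpha are the parameters discussed in this PR; the dataset and reward function are placeholders):

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Placeholder reward: longer completions get higher reward (illustration only).
def reward_len(completions, **kwargs):
    return [float(len(c)) for c in completions]

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO-iterative",
    sync_ref_model=True,         # assumed flag enabling iterative ref-model updates
    ref_model_sync_steps=64,     # sync every 64 optimizer steps
    ref_model_mixup_alpha=0.6,   # soft update: ref <- alpha * policy + (1 - alpha) * ref
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=load_dataset("trl-lib/tldr", split="train"),
)
trainer.train()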

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

@qgallouedec
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Co-authored-by: Quentin Gallouédec <[email protected]>
@qgallouedec (Member)

Nice! Can you try it locally with multi-GPU / DeepSpeed ZeRO 1/2/3? If you don't have the hardware, I can do it.

@qgallouedec (Member)

In the DeepSeek-R1 paper, I think they sync the ref after each epoch, no?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@shirinyamani (Contributor, Author) commented Jan 30, 2025

@qgallouedec
I do not have access to a multi-GPU setup at the moment, unfortunately! I can request access, but it might take a long time for a GPU to be assigned to me.

As for the update, I think they do it after one complete iteration (epoch), but I am not sure, because that could conflict with the current behavior: the default ref_model_sync_steps is 64, meaning the ref_model is updated after that many steps, but one epoch will probably take far more steps than that (i.e., for the one-epoch scenario we would have to set ref_model_sync_steps to the number of steps in an entire epoch).

Maybe I am misunderstanding?

[Screenshot: the GRPO algorithm from the paper]
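To make the step-count mismatch concrete, a rough back-of-the-envelope sketch (all numbers below are hypothetical):

# Hypothetical numbers to illustrate how many optimizer steps one epoch takes.
dataset_size = 50_000            # prompts in the training set
per_device_batch_size = 4
num_gpus = 8
grad_accumulation_steps = 2

effective_batch_size = per_device_batch_size * num_gpus * grad_accumulation_steps  # 64
steps_per_epoch = dataset_size // effective_batch_size                             # 781

# With the default ref_model_sync_steps = 64, the reference model would be
# synced about 781 // 64 = 12 times per epoch, not once per epoch as in the paper.
print(steps_per_epoch, steps_per_epoch // 64)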

@shirinyamani (Contributor, Author) commented Jan 30, 2025

Note that this algorithm and the ref_model update discussion come from the DeepSeekMath paper, where the GRPO math is laid out, but the question still remains! 🤔

@qgallouedec (Member)

Don't bother with multi-GPU, I'll test it myself.

I think we understand it the same way; I'm wondering what the user would expect.
The soft update as implemented probably gives better results, but it doesn't match the paper.
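For reference, the two update rules being compared, as a minimal sketch (alpha plays the role of ref_model_mixup_alpha; hard_sync is the periodic replacement described in the paper):

import torch

@torch.no_grad()
def soft_sync(policy: torch.nn.Module, ref: torch.nn.Module, alpha: float) -> None:
    # Soft (mixup) update, as currently implemented:
    # ref <- (1 - alpha) * ref + alpha * policy
    for ref_p, pol_p in zip(ref.parameters(), policy.parameters()):
        ref_p.data.mul_(1.0 - alpha).add_(pol_p.data, alpha=alpha)

@torch.no_grad()
def hard_sync(policy: torch.nn.Module, ref: torch.nn.Module) -> None:
    # Hard replacement, i.e. soft_sync with alpha = 1.0: ref <- policy
    for ref_p, pol_p in zip(ref.parameters(), policy.parameters()):
        ref_p.data.copy_(pol_p.data)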

Let me make some tests. I'll come back to you.

@shirinyamani (Contributor, Author)

@qgallouedec Did you get a chance to test this? 🤔

@qgallouedec (Member)

Not yet, will do asap

@qgallouedec (Member)

Actually I don't have time to test, unfortunately, but I think it's really worth:

  1. adding a param ref_model_sync_epochs to allow users to reproduce precisely the method described in the paper. Ideally, allow this value to be both an int and a float in (0, 1).
  2. running some experiments to check whether
  • the current default values make sense
  • it gives significantly different results

Do you want to handle 1. @shirinyamani?

In the meantime I'll merge this one.

@qgallouedec (Member) left a review comment:

thanks @shirinyamani!

@qgallouedec merged commit b2ae999 into huggingface:main on Feb 4, 2025 (13 checks passed).
@shirinyamani (Contributor, Author) commented Feb 4, 2025

Actually I don't have time to test, unfortunately, but I think it's really worth:

  1. adding a param ref_model_sync_epochs to allow users to reproduce precisely the method described in the paper. Ideally, allow this value to be both an int and a float in (0, 1).

Do you want to handle 1. @shirinyamani?

@qgallouedec
Sure, for brainstorming purposes, let's break down our options:

Goal: have a param like 0 < ref_model_sync_epochs < 1 that allows the user to update the ref_model after every X epochs (which can be a fraction, e.g. 0.2 of an epoch).

How to build:

Option 1): override the current on_step_end method in the SyncRefModelCallback class to reflect what we want. Right now it looks like this:

# Imports shown for completeness (the deepspeed guard below is a generic
# stand-in for the library's own availability check).
import importlib.util
from typing import Optional, Union

import torch
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from transformers import PreTrainedModel, TrainerCallback

if importlib.util.find_spec("deepspeed") is not None:
    import deepspeed


class SyncRefModelCallback(TrainerCallback):
    def __init__(
        self,
        ref_model: Union[PreTrainedModel, torch.nn.Module],
        accelerator: Optional[Accelerator],
    ):
        self.accelerator = accelerator
        self.ref_model = ref_model

    @staticmethod
    def _sync_target_model(model, target_model, alpha):
        # Soft update: target <- (1 - alpha) * target + alpha * model
        for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
            target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)

    @staticmethod
    def sync_target_model(model, target_model, alpha):
        deepspeed_plugin = AcceleratorState().deepspeed_plugin
        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
            # Under ZeRO-3 the parameters are sharded, so gather them before syncing.
            with deepspeed.zero.GatheredParameters(
                list(model.parameters()) + list(target_model.parameters()), modifier_rank=0
            ):
                if deepspeed.comm.get_rank() == 0:
                    SyncRefModelCallback._sync_target_model(model, target_model, alpha)
        else:
            SyncRefModelCallback._sync_target_model(model, target_model, alpha)

    def on_step_end(self, args, state, control, **kwargs):
        model: PreTrainedModel = kwargs["model"]

        # Sync the reference model every ref_model_sync_steps optimizer steps.
        if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0:
            if self.accelerator:
                model = self.accelerator.unwrap_model(model)
            self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha)

With the changes, it would be something like:

    def on_step_end(self, args, state, control, **kwargs):
        model: PreTrainedModel = kwargs["model"]

        # Calculate optimizer steps per epoch
        # (ref_model_sync_epochs would be stored in __init__ or read from args).
        steps_per_epoch = max(1, int(state.max_steps // args.num_train_epochs))

        # Determine whether we should sync based on ref_model_sync_epochs
        if isinstance(self.ref_model_sync_epochs, int):
            # Sync every ref_model_sync_epochs whole epochs
            should_sync = state.global_step % (self.ref_model_sync_epochs * steps_per_epoch) == 0
        elif isinstance(self.ref_model_sync_epochs, float):
            # Sync every fraction of an epoch, e.g. 0.2 -> five times per epoch
            sync_every = max(1, int(self.ref_model_sync_epochs * steps_per_epoch))
            should_sync = state.global_step % sync_every == 0
        else:
            raise ValueError("ref_model_sync_epochs must be an int or a float")

        if self.ref_model is not None and should_sync:
            if self.accelerator:
                model = self.accelerator.unwrap_model(model)
            self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha)

This might work if my understanding of on_step_end is correct! Since there is no docstring, my understanding of this method is that it checks whether the reference model should be synchronized (based on ref_model_sync_steps) and performs the synchronization if necessary.
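For context, a sketch of how such a callback gets attached inside the trainer (assumes a sync_ref_model flag on the config and the standard Trainer.add_callback API; self refers to the GRPOTrainer):

# Inside GRPOTrainer.__init__ (sketch; sync_ref_model is an assumed config flag).
if args.sync_ref_model:
    self.add_callback(
        SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)
    )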

Option 2): add ref_model_sync_epochs handling locally to the GRPO training loss update?

Thoughts ? 💭

@shirinyamani (Contributor, Author) commented Feb 5, 2025

One more question for you, @qgallouedec:

If the ref_model is updated so frequently, would that be the same as not having a ref_model at all? 🤔
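For context, the intuition behind this question in terms of the per-token KL penalty used in GRPO (following the DeepSeekMath formulation referenced above):

\mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \,\Vert\, \pi_{\mathrm{ref}}\right]
  = \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}
  - \log \frac{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}
  - 1

Right after a hard sync the ratio equals 1, so the estimator is 1 - log 1 - 1 = 0 and the KL penalty contributes nothing until the policy drifts away again; the more frequent the sync, the weaker the constraint.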
